import os
import pandas as pd
import transformers
from transformers import AutoTokenizer
from sklearn.model_selection import StratifiedGroupKFold
import torch
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import load_metric
import evaluate
import torch.nn as nn
from scipy.stats import pearsonr

# TODO: experiment with different optimizers.
# Batched inputs to the model produce outputs; single or multiple items both work.

os.environ["WANDB_DISABLED"] = "true"

# Evaluation metric: GLUE STS-B (Pearson / Spearman correlation), matching
# the regression-style similarity scores in this dataset.
metric = evaluate.load('glue', 'stsb')

# Checkpoint to fine-tune (local copy of microsoft/deberta-v3-small).
model_checkpoint = '/home/zxy/Projects/a_pretrainmodel/microsoft/deberta-v3-small'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Load the training data and join in the CPC context titles.
df = pd.read_csv('usp_data/train.csv')
df = df.sample(100)  # NOTE(review): debug-sized subsample — remove for a real run
df_title = pd.read_csv('usp_data/titles.csv')
df = df.merge(df_title, how='left', left_on='context', right_on='code')
df = df[['id', 'anchor', 'target', 'context', 'score', 'title']]

# Assign cross-validation folds, grouped by anchor so the same anchor never
# appears in both the train and validation portion of a fold.
kf = StratifiedGroupKFold(n_splits=2, shuffle=True, random_state=42)
df['fold'] = -1
for f, (t_, v_) in enumerate(kf.split(X=df, y=df['anchor'], groups=df['anchor'])):
    df.loc[v_, 'fold'] = f

# Build the first tokenizer segment: "<anchor>[SEP]<lowercased title>".
# The left join above can leave NaN titles; fill them before lowercasing
# (str.lower on a float NaN would raise TypeError).
df['input'] = df['anchor'] + tokenizer.sep_token + df['title'].fillna('').astype(str).str.lower()

# Keep a human-inspectable snapshot of the prepared data.
os.makedirs('backup', exist_ok=True)
df.to_excel('backup/test.xlsx', index=False)


# Dataset class that tokenizes one (input, target) text pair per item.
class TrainDataset(Dataset):
    """Wrap the prepared DataFrame for the HF Trainer.

    Each item is the tokenizer encoding of the (input, target) pair plus a
    float ``labels`` entry holding the similarity score, so default batch
    collation produces Trainer-ready dicts.
    """

    def __init__(self, df):
        # Materialize columns as plain numpy arrays once, up front.
        self.inputs = df['input'].values.astype(str)
        self.targets = df['target'].values.astype(str)
        self.label = df['score'].values

    def __getitem__(self, item):
        inputs = self.inputs[item]
        targets = self.targets[item]
        label = self.label[item]
        # Pair-encode with fixed-length padding so batches stack directly.
        # NOTE: relies on the module-level `tokenizer`.
        inputs = tokenizer(inputs, targets, max_length=64, padding='max_length', truncation=True)
        return {**inputs, 'labels': torch.as_tensor(label, dtype=torch.float)}

    def __len__(self):
        return len(self.inputs)

# Instantiate the model for regression: num_labels=1 makes
# AutoModelForSequenceClassification emit a single scalar and train with MSE.
num_labels = 1
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
# outputs=model(**inputs)
metric_name = "pearson"  # key in the compute_metrics output used for model selection
model_name = model_checkpoint.split("/")[-1]
batch_size = 5
args = TrainingArguments(
    f"{model_name}-finetuned",            # output directory for checkpoints/logs
    evaluation_strategy="epoch",          # evaluate once per epoch
    save_strategy="epoch",                # must match eval strategy for load_best_model_at_end
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    weight_decay=0.01,
    load_best_model_at_end=True,          # restore the best checkpoint (by metric below) after training
    metric_for_best_model=metric_name,
    save_total_limit=1,                   # keep at most one checkpoint on disk
)


def compute_metrics(eval_pred):
    """Score eval predictions with the GLUE STS-B metric (Pearson/Spearman).

    With num_labels=1 the model's predictions arrive shaped (N, 1), but the
    correlation metrics expect 1-D arrays — flatten before scoring.
    """
    predictions, labels = eval_pred
    predictions = predictions.reshape(-1)
    return metric.compute(predictions=predictions, references=labels)


# Hold out fold 0 for validation; train on the rest.
train_dataset = TrainDataset(df[df['fold'] != 0])
val_dataset = TrainDataset(df[df['fold'] == 0])
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

# Persist the fine-tuned weights (the best checkpoint, since
# load_best_model_at_end=True restored it after training).
os.makedirs('output', exist_ok=True)
# int() preserves the original "%d" formatting in case num_train_epochs is a float.
model_path = os.path.join('output', f"{model_checkpoint.split('/')[-1]}_{int(args.num_train_epochs)}.pth")
torch.save(model.state_dict(), model_path)

trainer.evaluate()

# TODO: run model prediction on held-out data
# predictions = trainer.predict(tokenized_datasets["validation"])
# print(predictions.predictions.shape, predictions.label_ids.shape)

# Automatic hyperparameter search — example code:
# def model_init():
#     return AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
#
#
# trainer = Trainer(
#     model_init=model_init,
#     args=args,
#     train_dataset=train_dataset,
#     eval_dataset=val_dataset,
#     tokenizer=tokenizer,
#     compute_metrics=compute_metrics
# )
#
# best_run = trainer.hyperparameter_search(n_trials=10, direction="maximize")
# for n, v in best_run.hyperparameters.items():
#     setattr(trainer.args, n, v)
# trainer.train()
